import os
import sys
import time
import signal
import subprocess
import numpy as np
import random
from copy import deepcopy

import gymnasium as gym
from gymnasium import spaces
from pettingzoo import ParallelEnv

import carla
import py_trees
import threading
from omegaconf import OmegaConf

from agent import AgentWrapper, HumanAgent, BehaviorAgent, CommAgent, CommOnlyAgent, ControlOnlyAgent, CoopernautReceiverAgent, CoopernautSenderAgent
from srunner.scenariomanager.carla_data_provider import CarlaDataProvider
from srunner.scenariomanager.timer import GameTime
from srunner.tools.route_parser import RouteParser

from envs.scenarios.route_scenario import RouteScenario
from envs.scenarios.multiagent_scenario import MultiAgentScenario
from envs.utils.evaluator import Evaluator
from envs.utils.video_recorder import VideoRecorder
from envs.utils.hud import HUD
from comm.central_receiver import CentralReceiver
from typing import Dict, List, Tuple, Any, Union, Iterable, Optional, TypeVar

class MultiAgentEnv(ParallelEnv):
    """
    Multi-agent environment for multi-agent navigation tasks in CARLA.
    It steps every live agent at once (PettingZoo ``ParallelEnv`` style).
    To use this base multi-agent environment:
    1. Initialize the environment,
    2. Call env.reset() to get an initial observation,
    3. Call env.step(actions) to apply actions to all agents, and receive the next obs and reward,
    4. Clean up the resources with env.close() after everything is done.
    """
    def __init__(self,
                 config,
                 max_episode_length=20,
                 ):
        """
        Connect the CARLA client and read the scenario config file.

        Args:
            config: run configuration. Read here: ``seed``, ``timeout``,
                ``port``, ``gpu_id``, ``num_workers``, ``logdir`` and
                ``env_config.scenario_config`` (path loaded with
                ``OmegaConf.load``). Other methods also read
                ``trafficManagerPort``, ``debug``, ``record_video`` and
                ``waitForEgo``.
            max_episode_length: episode time budget in seconds. Episodes are
                truncated after ``max_episode_length * frame_rate`` steps.
        """
        self._config = config
        self.seed = int(config.seed)                     # Random seed
        self.timeout = float(config.timeout)             # in Seconds
        self.frame_rate = float(20.0)                    # in Hz
        self.port = int(config.port)                     # Carla server port

        # Set up server (launching is currently disabled, see _init_server)
        self.server = None
        self._init_server(int(self._config.gpu_id),
                          int(self._config.num_workers),
                          int(self._config.port))

        # Set up client
        self.client = carla.Client('localhost', self.port)
        self.client.set_timeout(self.timeout)
        # Get scenario config file
        self._scenario_config = OmegaConf.load(config.env_config.scenario_config)
        self.ego_vehicles = []
        self.other_actors = []
        self.agents = []
        self.possible_agents = []
        self.focal_agents = []
        # Agents attached to background actors: never stepped, only cleaned up.
        self.background_agents = []

        # World / scenario handles, populated by reset()
        self.world = None
        self.traffic_manager = None
        self.scenario = None
        self.scenario_tree = None
        self.evaluator = None

        # Episode bookkeeping, populated by reset() / step()
        self.step_count = 0
        self._running = False
        self.timestamp = None
        self.gametimestamp = None

        # Set up Gymnasium variables
        self.action_spaces = {}
        self.observation_spaces = {}
        self.max_episode_length = max_episode_length     # in Seconds

        # Loggers
        self.hud = None
        self.video_recorder = None
        self.spectator = None
        self.spectator_transform = None
        self.central_receiver = CentralReceiver()
        self.logdir = self._config.logdir

    def _init_server(self, gpu_id=0, num_workers=1, port=2000):
        """
        Launch a local CARLA server through ``./envs/utils/launch_carla.sh``.

        NOTE: launching is currently disabled. The method returns immediately,
        ``self.server`` stays ``None``, and the CARLA server must be started
        externally. When the early ``return`` is removed, the launcher is
        skipped if ``config.clientonly`` is set.

        Args:
            gpu_id: GPU index forwarded to the launch script.
            num_workers: worker count forwarded to the launch script.
            port: CARLA server port forwarded to the launch script.
        """
        return  # deliberately disabled, see docstring
        if self._config.clientonly:
            return
        # Start CARLA instances
        print("Start and connecting to carla server")
        # Pass plain values (no leading space) so the script receives e.g. "0"
        # rather than " 0" as its positional arguments.
        self.server = subprocess.Popen(["./envs/utils/launch_carla.sh",   # shell script
                                        str(gpu_id),                      # GPU id
                                        str(num_workers),                 # num workers
                                        str(port),                        # port
                                        ])
        # Give the server time to come up before the client connects.
        time.sleep(10)

    def _load_world(self, town='Town12'):
        """
        Connect to a CARLA world, switch it to synchronous mode, configure the
        traffic manager and register everything with CarlaDataProvider.

        ``town`` is only loaded on the first call (while ``self.world`` is
        ``None``). Later calls reuse the server's current world.
        """
        if self.world is None:
            self.world = self.client.load_world(town, reset_settings=False)
        else:
            self.world = self.client.get_world()

        # Set world in synchronous mode with a fixed step of 1 / frame_rate seconds
        settings = self.world.get_settings()
        settings.synchronous_mode = True
        settings.fixed_delta_seconds = 1.0 / self.frame_rate
        self.world.apply_settings(settings)
        print("***** \nCarla world connected\n*****")

        # Set up traffic manager
        CarlaDataProvider.set_traffic_manager_port(int(self._config.trafficManagerPort))
        self.traffic_manager = self.client.get_trafficmanager(int(self._config.trafficManagerPort))
        self.traffic_manager.set_synchronous_mode(True)
        self.traffic_manager.set_random_device_seed(self.seed)

        # Register world in carla data provider
        CarlaDataProvider.set_client(self.client)
        CarlaDataProvider.set_world(self.world)

        self.world.tick()

    def _load_scenario(self):
        """
        Build a new ``MultiAgentScenario`` and cache its behavior tree, ego
        vehicles and other actors on the environment.
        """
        scenario = MultiAgentScenario(world=self.world,
                                      config=self._scenario_config,
                                      debug_mode=self._config.debug,
                                      terminate_on_failure=False,
                                     )
        self.scenario = scenario
        self.scenario_tree = scenario.scenario_tree
        # NOTE: this is the scenario's own list object; _cleanup clears it in place.
        self.ego_vehicles = scenario.ego_vehicles
        self.other_actors = scenario.other_actors

    def _setup_agents(self):
        """
        Create one agent per configured vehicle (ego vehicles first, then
        ``other_actors``), wrap it in ``AgentWrapper``, attach sensors and reset it.

        The agent class is chosen by substring match on ``agent_type``.
        Unmatched types fall back to an aggressive ``BehaviorAgent``.
        ``behavior_agent`` types get no sensors. Focal agents are also recorded
        in ``self.focal_agents``. Background actors are handled afterwards by
        ``_update_background_agents``.
        """
        self.possible_agents = []
        self.agents = []
        agent_configs = self._scenario_config.vehicles + self._scenario_config.other_actors
        num_egos = len(self.ego_vehicles)
        for i, agent_config in enumerate(agent_configs):
            # Config entries line up with ego_vehicles followed by other_actors
            vehicle = self.ego_vehicles[i] if i < num_egos else self.other_actors[i - num_egos]
            agent_type = agent_config.agent_type
            path_to_conf_file = None
            if 'human' in agent_type:
                agent = HumanAgent(path_to_conf_file)
                agent.setup(agent_config.name, path_to_conf_file)
            elif 'comm_agent' in agent_type:
                agent = CommAgent(vehicle, agent_config)
            elif 'control_only_agent' in agent_type:
                agent = ControlOnlyAgent(vehicle, agent_config)
            elif 'comm_only_agent' in agent_type:
                agent = CommOnlyAgent(vehicle, agent_config)
            elif 'coopernaut_receiver_agent' in agent_type:
                agent = CoopernautReceiverAgent(vehicle, agent_config)
            else:
                agent = BehaviorAgent(vehicle, agent_config, behavior='aggressive')
            # Universal wrapper for env to interface with the policy
            agent = AgentWrapper(agent, vehicle, agent_config)
            if 'behavior_agent' not in agent_type:
                agent.setup_sensors(vehicle, self._config.debug)
            agent.reset()
            self.agents.append(agent)
            # NOTE(review): holds AgentWrapper objects, whereas PettingZoo
            # conventionally expects agent ids here -- confirm with callers.
            self.possible_agents.append(agent)
            if agent.is_focal:
                self.focal_agents.append(agent)
        self._update_background_agents()

    def _update_background_agents(self):
        """
        Attach a ``CommOnlyAgent`` (non-focal ``noop_agent``) to every actor in
        the CarlaDataProvider pool that is neither an ego vehicle nor a
        scenario ``other_actor``.

        These agents are not stepped (they stay out of ``self.agents``).
        References are kept in ``self.background_agents`` so that ``_cleanup``
        can release their sensors.
        """
        # Build the membership set once; carla actors compare equal by id.
        known_ids = {actor.id for actor in self.ego_vehicles + self.other_actors
                     if actor is not None}
        template_config = deepcopy(self._scenario_config.vehicles[0])
        template_config.name = 'background'
        template_config.agent_type = 'noop_agent'
        template_config.task = 'share information with others'
        template_config.is_focal = False
        # Snapshot the pool: agent/sensor setup must not mutate what we iterate.
        for _, vehicle in list(CarlaDataProvider.get_actors()):
            if vehicle.id in known_ids:
                continue
            # Each background agent gets its own config copy instead of all
            # agents sharing (and re-mutating) a single object.
            agent_config = deepcopy(template_config)
            agent = CommOnlyAgent(vehicle, agent_config)
            agent = AgentWrapper(agent, vehicle, agent_config)
            agent.setup_sensors(vehicle, self._config.debug)
            agent.reset()
            self.background_agents.append(agent)

    def reset(self, seed=None, options=None):
        """
        Reset the environment according to the scenario config file.

        Args:
            seed: optional seed for this episode. ``None`` keeps the current
                seed. The seed is incremented after every reset.
            options: unused. Accepted for PettingZoo API compatibility.

        Returns:
            ``(observations, info)``. ``observations`` maps agent id to
            ``agent.observe()`` for live agents. ``info`` maps every agent id
            to an empty dict.
        """
        self._cleanup()

        # ``seed or self.seed`` would silently discard an explicit seed of 0.
        if seed is not None:
            self.seed = int(seed)
        CarlaDataProvider._random_seed = self.seed

        # Load the scenario again
        self._load_world(town=self._scenario_config.map)
        self._load_scenario()
        self.world.tick()

        # Reset game and env time
        self.step_count = 0
        GameTime.restart()
        self._running = True

        # Get time stamp
        snapshot = self.world.get_snapshot()
        self.timestamp = snapshot.timestamp
        self.gametimestamp = GameTime.get_time()

        # Set up agents
        self._setup_agents()

        # Get observations
        observations = self._get_observations()

        # Set up spectator
        self._setup_spectator()

        # Record video
        if self._config.record_video:
            self.video_recorder = VideoRecorder(self.spectator_transform)
            self.video_recorder.reset()

        # Evaluator for the scenario
        self.evaluator = Evaluator(self)

        # Display simulation information
        self.central_receiver = CentralReceiver()
        self.hud = HUD(width=1220, height=600, logdir=self.logdir)
        self.hud.update(self._current_video_frame(),
                        self.world,
                        self.timestamp,
                        self.central_receiver.get_lang_data())

        info = {agent._agent_id: {} for agent in self.agents}

        # Increment seed so consecutive resets differ
        self.seed += 1

        return observations, info

    def step(self, actions):
        """
        Apply ``actions`` and advance the simulation by one tick.

        Args:
            actions: ``Dict[AgentID, Action]``. Agents missing from the dict,
                or no longer alive, receive no control this step.

        Returns:
            ``(observations, rewards, terminated, truncated, infos)``. All are
            dicts keyed by agent id. ``terminated`` and ``truncated`` also
            carry an ``'__all__'`` entry. On the last step of an episode,
            ``infos`` is filled with per-agent evaluation results plus an
            ``'__all__'`` entry from ``evaluator.get_language_feedback()``.
        """
        infos = {agent._agent_id: {} for agent in self.agents}

        # Truncate once the episode budget (in ticks) has been used up
        is_truncated = self.step_count >= self.max_episode_length * self.frame_rate
        truncated = {agent._agent_id: is_truncated for agent in self.agents}
        truncated['__all__'] = is_truncated

        # Get time stamp
        snapshot = self.world.get_snapshot()
        self.timestamp = snapshot.timestamp
        self.gametimestamp = GameTime.get_time()

        # Apply actions to corresponding vehicles
        self._apply_control(actions)

        # Tick the scenario, then the world
        self.scenario_tree.tick_once()
        self.world.tick()

        # Update game time and actor information
        CarlaDataProvider.on_carla_tick()
        GameTime.on_carla_tick(self.timestamp)

        # Get next observations
        observations = self._get_observations()

        # Record video
        if self._config.record_video:
            self.video_recorder.record_frame()
        vehicle_actions = {agent._vehicle_id: actions.get(agent._agent_id, {"command": "None"})
                           for agent in self.agents}
        self.hud.update(self._current_video_frame(),
                        self.world,
                        self.timestamp,
                        self.central_receiver.get_lang_data(),
                        vehicle_actions)

        # Get rewards
        rewards = self._get_rewards()

        # Terminate
        terminated = self._update_terminations()

        # Collect evaluation results at the end of the episode
        if terminated['__all__'] or truncated['__all__']:
            self._collect_episode_infos(infos, truncated['__all__'])

        # Update env time
        self.step_count += 1
        return observations, rewards, terminated, truncated, infos

    def _collect_episode_infos(self, infos, is_truncated):
        """
        Fill ``infos`` in place with end-of-episode evaluator results, print
        them, and save the video if recording is enabled.
        """
        feedback = self.evaluator.get_feedback()
        print("Feedback:", feedback)
        for agent in self.agents:
            agent_info = infos[agent._agent_id]
            agent_info["feedback"] = self.evaluator.get_agent_feedback(agent)
            agent_info["reward"] = self.evaluator.get_agent_episode_reward(agent)
            agent_info["collision"] = self.evaluator.get_agent_collision(agent)
            agent_info["route_completion"] = self.evaluator.get_agent_route_completion(agent)
            agent_info["completion_time"] = self.step_count / self.frame_rate
            agent_info["timeout"] = 1 if (agent_info["route_completion"] > 0 and is_truncated) else 0
        infos['__all__'] = self.evaluator.get_language_feedback()
        for k, v in infos.items():
            print(k, v)
        if self._config.record_video:
            self.video_recorder.save()

    def _update_terminations(self):
        """
        Compute per-agent termination flags plus an ``'__all__'`` entry.

        If the scenario tree is no longer RUNNING, every agent terminates.
        Otherwise each agent's flag comes from
        ``evaluator.get_agent_termination``. A terminated focal agent whose
        vehicle is still alive has its vehicle destroyed and is cleaned up.
        ``agent.is_alive`` mirrors the flag. ``'__all__'`` is True once no
        focal agent is left running; with no focal agents at all it is True
        immediately.
        """
        terminated = {agent._agent_id: False for agent in self.agents}
        terminated['__all__'] = False

        if self.scenario_tree.status != py_trees.common.Status.RUNNING:
            self._running = False
        if not self._running:
            print("Scenario terminated")
            for agent in self.agents:
                terminated[agent._agent_id] = True
            terminated['__all__'] = True
        else:
            all_terminated = True
            for agent in self.agents:
                # Check if task is completed or collision occurred for an agent
                agent_termination = self.evaluator.get_agent_termination(agent)
                if agent_termination and agent._vehicle.is_alive and agent in self.focal_agents:
                    print("Agent {} terminated".format(agent._agent_id))
                    agent._vehicle.destroy()
                    agent.cleanup()
                terminated[agent._agent_id] = agent_termination
                agent.is_alive = not agent_termination
                if agent.is_focal and not agent_termination:
                    all_terminated = False
            terminated['__all__'] = all_terminated
        print("terminated", terminated)
        return terminated

    def render(self, mode='human'):
        """Rendering is handled by the HUD and video recorder; this is a no-op."""
        pass

    def _setup_spectator(self):
        """
        Place the spectator camera 50 m above either the configured
        ``spectator_transform`` (x, y, yaw) or the first ego vehicle, looking
        straight down.
        """
        self.spectator = self.world.get_spectator()
        spectator_transform = self._scenario_config.get('spectator_transform', None)
        if spectator_transform is None:
            spectator_transform = self.ego_vehicles[0].get_transform()
        else:
            transform = carla.Transform()
            transform.location.x = spectator_transform.x
            transform.location.y = spectator_transform.y
            transform.rotation.yaw = spectator_transform.yaw
            spectator_transform = transform
        spectator_transform.location.z += 50
        spectator_transform.rotation.pitch -= 90
        self.spectator_transform = spectator_transform
        self.spectator.set_transform(spectator_transform)

    def _current_video_frame(self):
        """
        Return the recorder's current frame, or ``None`` when not recording.

        NOTE(review): confirm that ``HUD.update`` accepts a ``None`` frame.
        """
        if self.video_recorder is None:
            return None
        return self.video_recorder.current_frame

    def _apply_control(self, actions):
        """
        Apply ``actions[agent_id]`` to each live agent that has an action.
        """
        for agent in self.agents:
            agent_id = agent._agent_id
            if agent_id in actions and agent.is_alive:
                agent.apply_control(actions[agent_id])

    def _get_observations(self):
        """
        Return ``{agent_id: agent.observe()}`` for every live agent.
        """
        return {agent._agent_id: agent.observe()
                for agent in self.agents if agent.is_alive}

    def _get_rewards(self):
        """
        Return ``{agent_id: reward}``. Live agents are scored by the
        evaluator; agents that are no longer alive get 0.
        """
        return {agent._agent_id: (self.evaluator.get_agent_reward(agent) if agent.is_alive else 0)
                for agent in self.agents}

    @staticmethod
    def _destroy_inner_agent(agent):
        """
        Best-effort ``destroy()`` of the policy wrapped by an ``AgentWrapper``.
        """
        inner = getattr(agent, '_agent', None)
        destroy = getattr(inner, 'destroy', None)
        if destroy is None:
            # Not every inner agent type exposes destroy(); nothing to release.
            return
        try:
            destroy()
        except Exception as err:  # best effort: never abort the remaining cleanup
            print("Failed to destroy agent {}: {}".format(getattr(agent, '_agent_id', agent), err))

    def _cleanup(self):
        """
        Tear down the current episode in this order:

        1. Scenario.
        2. Agents, including background agents.
        3. Video recorder.
        4. Ego vehicles.
        5. CarlaDataProvider registry.

        References are cleared once released, so a repeated call does not
        release the same scenario or recorder twice.
        """
        # Put the world and traffic manager back into asynchronous mode so that
        # tearing actors down does not wait on ticks.
        if self.world is not None:
            settings = self.world.get_settings()
            settings.synchronous_mode = False
            settings.fixed_delta_seconds = None
            self.world.apply_settings(settings)
            self.client.get_trafficmanager(int(self._config.trafficManagerPort)).set_synchronous_mode(False)

        # Terminate the scenario
        if self.scenario is not None:
            self.scenario.terminate()
            self.scenario = None

        # Release agent resources (sensors, wrapped policies)
        for agent in list(self.agents or []) + list(self.background_agents or []):
            agent.cleanup()
            self._destroy_inner_agent(agent)

        # Clean up video recorder
        if self.video_recorder is not None:
            self.video_recorder.destroy()
            self.video_recorder = None

        # Destroy ego vehicles unless configured to wait for them
        for i, vehicle in enumerate(self.ego_vehicles):
            if vehicle and not self._config.waitForEgo and vehicle.is_alive:
                print("Destroying ego vehicle {}".format(vehicle.id))
                vehicle.destroy()
            # Cleared in place: this list is the scenario's own ego list
            # (assigned in _load_scenario).
            self.ego_vehicles[i] = None

        self.ego_vehicles = []
        self.other_actors = []
        self.agents = []
        self.background_agents = []
        self.focal_agents = []
        self.possible_agents = []
        # Clean up and remove all entries of data provider
        CarlaDataProvider.cleanup()

    def close(self):
        """
        Release all episode resources and stop the CARLA server if this
        environment launched one.
        """
        self._cleanup()
        if self.server is not None:
            self.server.terminate()
            try:
                self.server.wait(timeout=10)
            except subprocess.TimeoutExpired:
                # Escalate if the launcher ignores SIGTERM.
                self.server.kill()
                self.server.wait()
            self.server = None

if __name__ == "__main__":
    # Smoke-test driver: run scripted/behavior/human agents for a few episodes.
    from envs.utils.config import get_env_args

    args = get_env_args()
    # MultiAgentEnv.__init__ takes the configuration positionally as ``config``;
    # the former ``args=args`` keyword raised TypeError.
    env = MultiAgentEnv(args)
    num_episodes = 1
    try:
        for episode in range(num_episodes):
            print(f"Episode {episode}")
            obs, info = env.reset()
            done = False
            while not done:
                action = {}
                for agent in env.agents:
                    if agent._agent_type == 'comm_agent':
                        action.update({agent._agent_id: agent._agent.run_step({'control': '', 'message': 'haha'})})
                    elif agent._agent_type == 'behavior_agent':
                        action.update({agent._agent_id: agent._agent.run_step()})
                    elif agent._agent_type == 'human' and agent._agent_id in obs:
                        # Terminated agents have no observation entry.
                        action.update({agent._agent_id: agent._agent.run_step(obs[agent._agent_id], env.gametimestamp)})
                print(obs)
                obs, rew, terminated, truncated, info = env.step(action)
                done = truncated['__all__'] or terminated['__all__']
    finally:
        # Always restore async mode and release actors, even on errors.
        env.close()
    print("Done!")
